From a8a60f43c56597259558261353b5bf7e953eed36 Mon Sep 17 00:00:00 2001 From: Marshall Pierce <575695+marshallpierce@users.noreply.github.com> Date: Fri, 1 Mar 2024 14:42:43 -0700 Subject: [PATCH] Decode main loop improvements - Rearrange main decoding loops to handle chunks of 32 bytes at a time, then 4 bytes at a time, meaning that `decode_suffix` need only handle 0-4 bytes, simplifying its code. Moderate speed gains of around 5-10%. - Improve error precision. `InvalidLength` now has a `usize` length indicating how many valid symbols were found, but that the count of those symbols was invalid. Before, it just did `input % 4 == `, which was harder to reason about, as there might be padding etc. DecoderReader now also precisely reports the suitable InvalidByte if an earlier block of decoding found padding that was valid in that context, but more padding was found later, rendering that earlier padding invalid. - Tidy up decode tests. There were some duplicated scenarios, and certain aspects are now tested in more detail. --- src/decode.rs | 26 +- src/engine/general_purpose/decode.rs | 366 +++++------- src/engine/general_purpose/decode_suffix.rs | 76 +-- src/engine/naive.rs | 13 +- src/engine/tests.rs | 625 +++++++------------- src/read/decoder.rs | 70 ++- 6 files changed, 466 insertions(+), 710 deletions(-) diff --git a/src/decode.rs b/src/decode.rs index 5230fd3..0f66c74 100644 --- a/src/decode.rs +++ b/src/decode.rs @@ -9,18 +9,20 @@ use std::error; #[derive(Clone, Debug, PartialEq, Eq)] pub enum DecodeError { /// An invalid byte was found in the input. The offset and offending byte are provided. - /// Padding characters (`=`) interspersed in the encoded form will be treated as invalid bytes. + /// + /// Padding characters (`=`) interspersed in the encoded form are invalid, as they may only + /// be present as the last 0-2 bytes of input. + /// + /// This error may also indicate that extraneous trailing input bytes are present, causing + /// otherwise valid padding to no longer be the last bytes of input. InvalidByte(usize, u8), - /// The length of the input is invalid. - /// A typical cause of this is stray trailing whitespace or other separator bytes. - /// In the case where excess trailing bytes have produced an invalid length *and* the last byte - /// is also an invalid base64 symbol (as would be the case for whitespace, etc), `InvalidByte` - /// will be emitted instead of `InvalidLength` to make the issue easier to debug. - InvalidLength, + /// The length of the input, as measured in valid base64 symbols, is invalid. + /// There must be 2-4 symbols in the last input quad. + InvalidLength(usize), /// The last non-padding input symbol's encoded 6 bits have nonzero bits that will be discarded. /// This is indicative of corrupted or truncated Base64. - /// Unlike `InvalidByte`, which reports symbols that aren't in the alphabet, this error is for - /// symbols that are in the alphabet but represent nonsensical encodings. + /// Unlike [DecodeError::InvalidByte], which reports symbols that aren't in the alphabet, + /// this error is for symbols that are in the alphabet but represent nonsensical encodings. InvalidLastSymbol(usize, u8), /// The nature of the padding was not as configured: absent or incorrect when it must be /// canonical, or present when it must be absent, etc. @@ -30,8 +32,10 @@ pub enum DecodeError { impl fmt::Display for DecodeError { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match *self { - Self::InvalidByte(index, byte) => write!(f, "Invalid byte {}, offset {}.", byte, index), - Self::InvalidLength => write!(f, "Encoded text cannot have a 6-bit remainder."), + Self::InvalidByte(index, byte) => { + write!(f, "Invalid symbol {}, offset {}.", byte, index) + } + Self::InvalidLength(len) => write!(f, "Invalid input length: {}", len), Self::InvalidLastSymbol(index, byte) => { write!(f, "Invalid last symbol {}, offset {}.", byte, index) } diff --git a/src/engine/general_purpose/decode.rs b/src/engine/general_purpose/decode.rs index 21a386f..31c289e 100644 --- a/src/engine/general_purpose/decode.rs +++ b/src/engine/general_purpose/decode.rs @@ -3,45 +3,25 @@ use crate::{ DecodeError, PAD_BYTE, }; -// decode logic operates on chunks of 8 input bytes without padding -const INPUT_CHUNK_LEN: usize = 8; -const DECODED_CHUNK_LEN: usize = 6; - -// we read a u64 and write a u64, but a u64 of input only yields 6 bytes of output, so the last -// 2 bytes of any output u64 should not be counted as written to (but must be available in a -// slice). -const DECODED_CHUNK_SUFFIX: usize = 2; - -// how many u64's of input to handle at a time -const CHUNKS_PER_FAST_LOOP_BLOCK: usize = 4; - -const INPUT_BLOCK_LEN: usize = CHUNKS_PER_FAST_LOOP_BLOCK * INPUT_CHUNK_LEN; - -// includes the trailing 2 bytes for the final u64 write -const DECODED_BLOCK_LEN: usize = - CHUNKS_PER_FAST_LOOP_BLOCK * DECODED_CHUNK_LEN + DECODED_CHUNK_SUFFIX; - #[doc(hidden)] pub struct GeneralPurposeEstimate { - /// Total number of decode chunks, including a possibly partial last chunk - num_chunks: usize, - decoded_len_estimate: usize, + rem: usize, + conservative_len: usize, } impl GeneralPurposeEstimate { pub(crate) fn new(encoded_len: usize) -> Self { - // Formulas that won't overflow + let rem = encoded_len % 4; Self { - num_chunks: encoded_len / INPUT_CHUNK_LEN - + (encoded_len % INPUT_CHUNK_LEN > 0) as usize, - decoded_len_estimate: (encoded_len / 4 + (encoded_len % 4 > 0) as usize) * 3, + rem, + conservative_len: (encoded_len / 4 + (rem > 0) as usize) * 3, } } } impl DecodeEstimate for GeneralPurposeEstimate { fn decoded_len_estimate(&self) -> usize { - self.decoded_len_estimate + self.conservative_len } } @@ -59,264 +39,237 @@ pub(crate) fn decode_helper( decode_allow_trailing_bits: bool, padding_mode: DecodePaddingMode, ) -> Result { - let remainder_len = input.len() % INPUT_CHUNK_LEN; - - // Because the fast decode loop writes in groups of 8 bytes (unrolled to - // CHUNKS_PER_FAST_LOOP_BLOCK times 8 bytes, where possible) and outputs 8 bytes at a time (of - // which only 6 are valid data), we need to be sure that we stop using the fast decode loop - // soon enough that there will always be 2 more bytes of valid data written after that loop. - let trailing_bytes_to_skip = match remainder_len { - // if input is a multiple of the chunk size, ignore the last chunk as it may have padding, - // and the fast decode logic cannot handle padding - 0 => INPUT_CHUNK_LEN, - // 1 and 5 trailing bytes are illegal: can't decode 6 bits of input into a byte - 1 | 5 => { - // trailing whitespace is so common that it's worth it to check the last byte to - // possibly return a better error message - if let Some(b) = input.last() { - if *b != PAD_BYTE && decode_table[*b as usize] == INVALID_VALUE { - return Err(DecodeError::InvalidByte(input.len() - 1, *b)); - } - } - - return Err(DecodeError::InvalidLength); + // detect a trailing invalid byte, like a newline, as a user convenience + if estimate.rem == 1 { + let last_byte = input[input.len() - 1]; + // exclude pad bytes; might be part of padding that extends from earlier in the input + if last_byte != PAD_BYTE && decode_table[usize::from(last_byte)] == INVALID_VALUE { + return Err(DecodeError::InvalidByte(input.len() - 1, last_byte)); } - // This will decode to one output byte, which isn't enough to overwrite the 2 extra bytes - // written by the fast decode loop. So, we have to ignore both these 2 bytes and the - // previous chunk. - 2 => INPUT_CHUNK_LEN + 2, - // If this is 3 un-padded chars, then it would actually decode to 2 bytes. However, if this - // is an erroneous 2 chars + 1 pad char that would decode to 1 byte, then it should fail - // with an error, not panic from going past the bounds of the output slice, so we let it - // use stage 3 + 4. - 3 => INPUT_CHUNK_LEN + 3, - // This can also decode to one output byte because it may be 2 input chars + 2 padding - // chars, which would decode to 1 byte. - 4 => INPUT_CHUNK_LEN + 4, - // Everything else is a legal decode len (given that we don't require padding), and will - // decode to at least 2 bytes of output. - _ => remainder_len, - }; - - // rounded up to include partial chunks - let mut remaining_chunks = estimate.num_chunks; - - let mut input_index = 0; - let mut output_index = 0; + } + // skip last quad, even if it's complete, as it may have padding + let input_complete_nonterminal_quads_len = input + .len() + .saturating_sub(estimate.rem) + // if rem was 0, subtract 4 to avoid padding + .saturating_sub((estimate.rem == 0) as usize * 4); + debug_assert!( + input.is_empty() || (1..=4).contains(&(input.len() - input_complete_nonterminal_quads_len)) + ); + + const UNROLLED_INPUT_CHUNK_SIZE: usize = 32; + const UNROLLED_OUTPUT_CHUNK_SIZE: usize = UNROLLED_INPUT_CHUNK_SIZE / 4 * 3; + + let input_complete_quads_after_unrolled_chunks_len = + input_complete_nonterminal_quads_len % UNROLLED_INPUT_CHUNK_SIZE; + + let input_unrolled_loop_len = + input_complete_nonterminal_quads_len - input_complete_quads_after_unrolled_chunks_len; + + // chunks of 32 bytes + for (chunk_index, chunk) in input[..input_unrolled_loop_len] + .chunks_exact(UNROLLED_INPUT_CHUNK_SIZE) + .enumerate() { - let length_of_fast_decode_chunks = input.len().saturating_sub(trailing_bytes_to_skip); - - // Fast loop, stage 1 - // manual unroll to CHUNKS_PER_FAST_LOOP_BLOCK of u64s to amortize slice bounds checks - if let Some(max_start_index) = length_of_fast_decode_chunks.checked_sub(INPUT_BLOCK_LEN) { - while input_index <= max_start_index { - let input_slice = &input[input_index..(input_index + INPUT_BLOCK_LEN)]; - let output_slice = &mut output[output_index..(output_index + DECODED_BLOCK_LEN)]; - - decode_chunk( - &input_slice[0..], - input_index, - decode_table, - &mut output_slice[0..], - )?; - decode_chunk( - &input_slice[8..], - input_index + 8, - decode_table, - &mut output_slice[6..], - )?; - decode_chunk( - &input_slice[16..], - input_index + 16, - decode_table, - &mut output_slice[12..], - )?; - decode_chunk( - &input_slice[24..], - input_index + 24, - decode_table, - &mut output_slice[18..], - )?; - - input_index += INPUT_BLOCK_LEN; - output_index += DECODED_BLOCK_LEN - DECODED_CHUNK_SUFFIX; - remaining_chunks -= CHUNKS_PER_FAST_LOOP_BLOCK; - } - } - - // Fast loop, stage 2 (aka still pretty fast loop) - // 8 bytes at a time for whatever we didn't do in stage 1. - if let Some(max_start_index) = length_of_fast_decode_chunks.checked_sub(INPUT_CHUNK_LEN) { - while input_index < max_start_index { - decode_chunk( - &input[input_index..(input_index + INPUT_CHUNK_LEN)], - input_index, - decode_table, - &mut output - [output_index..(output_index + DECODED_CHUNK_LEN + DECODED_CHUNK_SUFFIX)], - )?; - - output_index += DECODED_CHUNK_LEN; - input_index += INPUT_CHUNK_LEN; - remaining_chunks -= 1; - } - } - } + let input_index = chunk_index * UNROLLED_INPUT_CHUNK_SIZE; + let chunk_output = &mut output[chunk_index * UNROLLED_OUTPUT_CHUNK_SIZE + ..(chunk_index + 1) * UNROLLED_OUTPUT_CHUNK_SIZE]; - // Stage 3 - // If input length was such that a chunk had to be deferred until after the fast loop - // because decoding it would have produced 2 trailing bytes that wouldn't then be - // overwritten, we decode that chunk here. This way is slower but doesn't write the 2 - // trailing bytes. - // However, we still need to avoid the last chunk (partial or complete) because it could - // have padding, so we always do 1 fewer to avoid the last chunk. - for _ in 1..remaining_chunks { - decode_chunk_precise( - &input[input_index..], + decode_chunk_8( + &chunk[0..8], input_index, decode_table, - &mut output[output_index..(output_index + DECODED_CHUNK_LEN)], + &mut chunk_output[0..6], + )?; + decode_chunk_8( + &chunk[8..16], + input_index + 8, + decode_table, + &mut chunk_output[6..12], + )?; + decode_chunk_8( + &chunk[16..24], + input_index + 16, + decode_table, + &mut chunk_output[12..18], + )?; + decode_chunk_8( + &chunk[24..32], + input_index + 24, + decode_table, + &mut chunk_output[18..24], )?; - - input_index += INPUT_CHUNK_LEN; - output_index += DECODED_CHUNK_LEN; } - // always have one more (possibly partial) block of 8 input - debug_assert!(input.len() - input_index > 1 || input.is_empty()); - debug_assert!(input.len() - input_index <= 8); + // remaining quads, except for the last possibly partial one, as it may have padding + let output_unrolled_loop_len = input_unrolled_loop_len / 4 * 3; + let output_complete_quad_len = input_complete_nonterminal_quads_len / 4 * 3; + { + let output_after_unroll = &mut output[output_unrolled_loop_len..output_complete_quad_len]; + + for (chunk_index, chunk) in input + [input_unrolled_loop_len..input_complete_nonterminal_quads_len] + .chunks_exact(4) + .enumerate() + { + let chunk_output = &mut output_after_unroll[chunk_index * 3..chunk_index * 3 + 3]; + + decode_chunk_4( + chunk, + input_unrolled_loop_len + chunk_index * 4, + decode_table, + chunk_output, + )?; + } + } super::decode_suffix::decode_suffix( input, - input_index, + input_complete_nonterminal_quads_len, output, - output_index, + output_complete_quad_len, decode_table, decode_allow_trailing_bits, padding_mode, ) } -/// Decode 8 bytes of input into 6 bytes of output. 8 bytes of output will be written, but only the -/// first 6 of those contain meaningful data. +/// Decode 8 bytes of input into 6 bytes of output. /// -/// `input` is the bytes to decode, of which the first 8 bytes will be processed. +/// `input` is the 8 bytes to decode. /// `index_at_start_of_input` is the offset in the overall input (used for reporting errors /// accurately) /// `decode_table` is the lookup table for the particular base64 alphabet. -/// `output` will have its first 8 bytes overwritten, of which only the first 6 are valid decoded -/// data. +/// `output` will have its first 6 bytes overwritten // yes, really inline (worth 30-50% speedup) #[inline(always)] -fn decode_chunk( +fn decode_chunk_8( input: &[u8], index_at_start_of_input: usize, decode_table: &[u8; 256], output: &mut [u8], ) -> Result<(), DecodeError> { - let morsel = decode_table[input[0] as usize]; + let morsel = decode_table[usize::from(input[0])]; if morsel == INVALID_VALUE { return Err(DecodeError::InvalidByte(index_at_start_of_input, input[0])); } - let mut accum = (morsel as u64) << 58; + let mut accum = u64::from(morsel) << 58; - let morsel = decode_table[input[1] as usize]; + let morsel = decode_table[usize::from(input[1])]; if morsel == INVALID_VALUE { return Err(DecodeError::InvalidByte( index_at_start_of_input + 1, input[1], )); } - accum |= (morsel as u64) << 52; + accum |= u64::from(morsel) << 52; - let morsel = decode_table[input[2] as usize]; + let morsel = decode_table[usize::from(input[2])]; if morsel == INVALID_VALUE { return Err(DecodeError::InvalidByte( index_at_start_of_input + 2, input[2], )); } - accum |= (morsel as u64) << 46; + accum |= u64::from(morsel) << 46; - let morsel = decode_table[input[3] as usize]; + let morsel = decode_table[usize::from(input[3])]; if morsel == INVALID_VALUE { return Err(DecodeError::InvalidByte( index_at_start_of_input + 3, input[3], )); } - accum |= (morsel as u64) << 40; + accum |= u64::from(morsel) << 40; - let morsel = decode_table[input[4] as usize]; + let morsel = decode_table[usize::from(input[4])]; if morsel == INVALID_VALUE { return Err(DecodeError::InvalidByte( index_at_start_of_input + 4, input[4], )); } - accum |= (morsel as u64) << 34; + accum |= u64::from(morsel) << 34; - let morsel = decode_table[input[5] as usize]; + let morsel = decode_table[usize::from(input[5])]; if morsel == INVALID_VALUE { return Err(DecodeError::InvalidByte( index_at_start_of_input + 5, input[5], )); } - accum |= (morsel as u64) << 28; + accum |= u64::from(morsel) << 28; - let morsel = decode_table[input[6] as usize]; + let morsel = decode_table[usize::from(input[6])]; if morsel == INVALID_VALUE { return Err(DecodeError::InvalidByte( index_at_start_of_input + 6, input[6], )); } - accum |= (morsel as u64) << 22; + accum |= u64::from(morsel) << 22; - let morsel = decode_table[input[7] as usize]; + let morsel = decode_table[usize::from(input[7])]; if morsel == INVALID_VALUE { return Err(DecodeError::InvalidByte( index_at_start_of_input + 7, input[7], )); } - accum |= (morsel as u64) << 16; + accum |= u64::from(morsel) << 16; - write_u64(output, accum); + output[..6].copy_from_slice(&accum.to_be_bytes()[..6]); Ok(()) } -/// Decode an 8-byte chunk, but only write the 6 bytes actually decoded instead of including 2 -/// trailing garbage bytes. -#[inline] -fn decode_chunk_precise( +/// Like [decode_chunk_8] but for 4 bytes of input and 3 bytes of output. +#[inline(always)] +fn decode_chunk_4( input: &[u8], index_at_start_of_input: usize, decode_table: &[u8; 256], output: &mut [u8], ) -> Result<(), DecodeError> { - let mut tmp_buf = [0_u8; 8]; + let morsel = decode_table[usize::from(input[0])]; + if morsel == INVALID_VALUE { + return Err(DecodeError::InvalidByte(index_at_start_of_input, input[0])); + } + let mut accum = u32::from(morsel) << 26; - decode_chunk( - input, - index_at_start_of_input, - decode_table, - &mut tmp_buf[..], - )?; + let morsel = decode_table[usize::from(input[1])]; + if morsel == INVALID_VALUE { + return Err(DecodeError::InvalidByte( + index_at_start_of_input + 1, + input[1], + )); + } + accum |= u32::from(morsel) << 20; + + let morsel = decode_table[usize::from(input[2])]; + if morsel == INVALID_VALUE { + return Err(DecodeError::InvalidByte( + index_at_start_of_input + 2, + input[2], + )); + } + accum |= u32::from(morsel) << 14; + + let morsel = decode_table[usize::from(input[3])]; + if morsel == INVALID_VALUE { + return Err(DecodeError::InvalidByte( + index_at_start_of_input + 3, + input[3], + )); + } + accum |= u32::from(morsel) << 8; - output[0..6].copy_from_slice(&tmp_buf[0..6]); + output[..3].copy_from_slice(&accum.to_be_bytes()[..3]); Ok(()) } -#[inline] -fn write_u64(output: &mut [u8], value: u64) { - output[..8].copy_from_slice(&value.to_be_bytes()); -} - #[cfg(test)] mod tests { use super::*; @@ -324,37 +277,36 @@ mod tests { use crate::engine::general_purpose::STANDARD; #[test] - fn decode_chunk_precise_writes_only_6_bytes() { + fn decode_chunk_8_writes_only_6_bytes() { let input = b"Zm9vYmFy"; // "foobar" let mut output = [0_u8, 1, 2, 3, 4, 5, 6, 7]; - decode_chunk_precise(&input[..], 0, &STANDARD.decode_table, &mut output).unwrap(); + decode_chunk_8(&input[..], 0, &STANDARD.decode_table, &mut output).unwrap(); assert_eq!(&vec![b'f', b'o', b'o', b'b', b'a', b'r', 6, 7], &output); } #[test] - fn decode_chunk_writes_8_bytes() { - let input = b"Zm9vYmFy"; // "foobar" - let mut output = [0_u8, 1, 2, 3, 4, 5, 6, 7]; + fn decode_chunk_4_writes_only_3_bytes() { + let input = b"Zm9v"; // "foobar" + let mut output = [0_u8, 1, 2, 3]; - decode_chunk(&input[..], 0, &STANDARD.decode_table, &mut output).unwrap(); - assert_eq!(&vec![b'f', b'o', b'o', b'b', b'a', b'r', 0, 0], &output); + decode_chunk_4(&input[..], 0, &STANDARD.decode_table, &mut output).unwrap(); + assert_eq!(&vec![b'f', b'o', b'o', 3], &output); } #[test] fn estimate_short_lengths() { - for (range, (num_chunks, decoded_len_estimate)) in [ - (0..=0, (0, 0)), - (1..=4, (1, 3)), - (5..=8, (1, 6)), - (9..=12, (2, 9)), - (13..=16, (2, 12)), - (17..=20, (3, 15)), + for (range, decoded_len_estimate) in [ + (0..=0, 0), + (1..=4, 3), + (5..=8, 6), + (9..=12, 9), + (13..=16, 12), + (17..=20, 15), ] { for encoded_len in range { let estimate = GeneralPurposeEstimate::new(encoded_len); - assert_eq!(num_chunks, estimate.num_chunks); - assert_eq!(decoded_len_estimate, estimate.decoded_len_estimate); + assert_eq!(decoded_len_estimate, estimate.decoded_len_estimate()); } } } @@ -369,15 +321,7 @@ mod tests { let len_128 = encoded_len as u128; let estimate = GeneralPurposeEstimate::new(encoded_len); - assert_eq!( - ((len_128 + (INPUT_CHUNK_LEN - 1) as u128) / (INPUT_CHUNK_LEN as u128)) - as usize, - estimate.num_chunks - ); - assert_eq!( - ((len_128 + 3) / 4 * 3) as usize, - estimate.decoded_len_estimate - ); + assert_eq!((len_128 + 3) / 4 * 3, estimate.conservative_len as u128); }) } } diff --git a/src/engine/general_purpose/decode_suffix.rs b/src/engine/general_purpose/decode_suffix.rs index 9fbb0d5..3d52ae5 100644 --- a/src/engine/general_purpose/decode_suffix.rs +++ b/src/engine/general_purpose/decode_suffix.rs @@ -3,7 +3,7 @@ use crate::{ DecodeError, PAD_BYTE, }; -/// Decode the last 1-8 bytes, checking for trailing set bits and padding per the provided +/// Decode the last 0-4 bytes, checking for trailing set bits and padding per the provided /// parameters. /// /// Returns the decode metadata representing the total number of bytes decoded, including the ones @@ -17,16 +17,18 @@ pub(crate) fn decode_suffix( decode_allow_trailing_bits: bool, padding_mode: DecodePaddingMode, ) -> Result { + debug_assert!((input.len() - input_index) <= 4); + // Decode any leftovers that might not be a complete input chunk of 8 bytes. // Use a u64 as a stack-resident 8 byte buffer. let mut morsels_in_leftover = 0; - let mut padding_bytes = 0; - let mut first_padding_index: usize = 0; + let mut padding_bytes_count = 0; + // offset from input_index + let mut first_padding_offset: usize = 0; let mut last_symbol = 0_u8; - let start_of_leftovers = input_index; - let mut morsels = [0_u8; 8]; + let mut morsels = [0_u8; 4]; - for (i, &b) in input[start_of_leftovers..].iter().enumerate() { + for (leftover_index, &b) in input[input_index..].iter().enumerate() { // '=' padding if b == PAD_BYTE { // There can be bad padding bytes in a few ways: @@ -41,30 +43,30 @@ pub(crate) fn decode_suffix( // Per config, non-canonical but still functional non- or partially-padded base64 // may be treated as an error condition. - if i % 4 < 2 { + if leftover_index < 2 { // Check for case #2. - let bad_padding_index = start_of_leftovers - + if padding_bytes > 0 { + let bad_padding_index = input_index + + if padding_bytes_count > 0 { // If we've already seen padding, report the first padding index. // This is to be consistent with the normal decode logic: it will report an // error on the first padding character (since it doesn't expect to see // anything but actual encoded data). // This could only happen if the padding started in the previous quad since - // otherwise this case would have been hit at i % 4 == 0 if it was the same + // otherwise this case would have been hit at i == 4 if it was the same // quad. - first_padding_index + first_padding_offset } else { // haven't seen padding before, just use where we are now - i + leftover_index }; return Err(DecodeError::InvalidByte(bad_padding_index, b)); } - if padding_bytes == 0 { - first_padding_index = i; + if padding_bytes_count == 0 { + first_padding_offset = leftover_index; } - padding_bytes += 1; + padding_bytes_count += 1; continue; } @@ -72,9 +74,9 @@ pub(crate) fn decode_suffix( // To make '=' handling consistent with the main loop, don't allow // non-suffix '=' in trailing chunk either. Report error as first // erroneous padding. - if padding_bytes > 0 { + if padding_bytes_count > 0 { return Err(DecodeError::InvalidByte( - start_of_leftovers + first_padding_index, + input_index + first_padding_offset, PAD_BYTE, )); } @@ -85,22 +87,31 @@ pub(crate) fn decode_suffix( // Pack the leftovers from left to right. let morsel = decode_table[b as usize]; if morsel == INVALID_VALUE { - return Err(DecodeError::InvalidByte(start_of_leftovers + i, b)); + return Err(DecodeError::InvalidByte(input_index + leftover_index, b)); } morsels[morsels_in_leftover] = morsel; morsels_in_leftover += 1; } + // If there was 1 trailing byte, and it was valid, and we got to this point without hitting + // an invalid byte, now we can report invalid length + if !input.is_empty() && morsels_in_leftover < 2 { + return Err(DecodeError::InvalidLength( + input_index + morsels_in_leftover, + )); + } + match padding_mode { DecodePaddingMode::Indifferent => { /* everything we care about was already checked */ } DecodePaddingMode::RequireCanonical => { - if (padding_bytes + morsels_in_leftover) % 4 != 0 { + // allow empty input + if (padding_bytes_count + morsels_in_leftover) % 4 != 0 { return Err(DecodeError::InvalidPadding); } } DecodePaddingMode::RequireNone => { - if padding_bytes > 0 { + if padding_bytes_count > 0 { // check at the end to make sure we let the cases of padding that should be InvalidByte // get hit return Err(DecodeError::InvalidPadding); @@ -120,27 +131,21 @@ pub(crate) fn decode_suffix( // useless since there are no more symbols to provide the necessary 4 additional bits // to finish the second original byte. - // TODO how do we know this? - debug_assert!(morsels_in_leftover != 1 && morsels_in_leftover != 5); let leftover_bytes_to_append = morsels_in_leftover * 6 / 8; // Put the up to 6 complete bytes as the high bytes. // Gain a couple percent speedup from nudging these ORs to use more ILP with a two-way split. - let mut leftover_num = ((u64::from(morsels[0]) << 58) - | (u64::from(morsels[1]) << 52) - | (u64::from(morsels[2]) << 46) - | (u64::from(morsels[3]) << 40)) - | ((u64::from(morsels[4]) << 34) - | (u64::from(morsels[5]) << 28) - | (u64::from(morsels[6]) << 22) - | (u64::from(morsels[7]) << 16)); + let mut leftover_num = (u32::from(morsels[0]) << 26) + | (u32::from(morsels[1]) << 20) + | (u32::from(morsels[2]) << 14) + | (u32::from(morsels[3]) << 8); // if there are bits set outside the bits we care about, last symbol encodes trailing bits that // will not be included in the output - let mask = !0 >> (leftover_bytes_to_append * 8); + let mask = !0_u32 >> (leftover_bytes_to_append * 8); if !decode_allow_trailing_bits && (leftover_num & mask) != 0 { // last morsel is at `morsels_in_leftover` - 1 return Err(DecodeError::InvalidLastSymbol( - start_of_leftovers + morsels_in_leftover - 1, + input_index + morsels_in_leftover - 1, last_symbol, )); } @@ -148,16 +153,17 @@ pub(crate) fn decode_suffix( // Strangely, this approach benchmarks better than writing bytes one at a time, // or copy_from_slice into output. for _ in 0..leftover_bytes_to_append { - let hi_byte = (leftover_num >> 56) as u8; + let hi_byte = (leftover_num >> 24) as u8; leftover_num <<= 8; + // TODO use checked writes output[output_index] = hi_byte; output_index += 1; } Ok(DecodeMetadata::new( output_index, - if padding_bytes > 0 { - Some(input_index + first_padding_index) + if padding_bytes_count > 0 { + Some(input_index + first_padding_offset) } else { None }, diff --git a/src/engine/naive.rs b/src/engine/naive.rs index 6a50cbe..2546a6f 100644 --- a/src/engine/naive.rs +++ b/src/engine/naive.rs @@ -115,15 +115,12 @@ impl Engine for Naive { if estimate.rem == 1 { // trailing whitespace is so common that it's worth it to check the last byte to // possibly return a better error message - if let Some(b) = input.last() { - if *b != PAD_BYTE - && self.decode_table[*b as usize] == general_purpose::INVALID_VALUE - { - return Err(DecodeError::InvalidByte(input.len() - 1, *b)); - } + let last_byte = input[input.len() - 1]; + if last_byte != PAD_BYTE + && self.decode_table[usize::from(last_byte)] == general_purpose::INVALID_VALUE + { + return Err(DecodeError::InvalidByte(input.len() - 1, last_byte)); } - - return Err(DecodeError::InvalidLength); } let mut input_index = 0_usize; diff --git a/src/engine/tests.rs b/src/engine/tests.rs index b048005..b73f108 100644 --- a/src/engine/tests.rs +++ b/src/engine/tests.rs @@ -365,26 +365,49 @@ fn decode_detect_invalid_last_symbol(engine_wrapper: E) { } #[apply(all_engines)] -fn decode_detect_invalid_last_symbol_when_length_is_also_invalid( - engine_wrapper: E, -) { - let mut rng = seeded_rng(); - - // check across enough lengths that it would likely cover any implementation's various internal - // small/large input division +fn decode_detect_1_valid_symbol_in_last_quad_invalid_length(engine_wrapper: E) { for len in (0_usize..256).map(|len| len * 4 + 1) { - let engine = E::random_alphabet(&mut rng, &STANDARD); + for mode in all_pad_modes() { + let mut input = vec![b'A'; len]; - let mut input = vec![b'A'; len]; + let engine = E::standard_with_pad_mode(true, mode); - // with a valid last char, it's InvalidLength - assert_eq!(Err(DecodeError::InvalidLength), engine.decode(&input)); - // after mangling the last char, it's InvalidByte - input[len - 1] = b'"'; - assert_eq!( - Err(DecodeError::InvalidByte(len - 1, b'"')), - engine.decode(&input) - ); + assert_eq!(Err(DecodeError::InvalidLength(len)), engine.decode(&input)); + // if we add padding, then the first pad byte in the quad is invalid because it should + // be the second symbol + for _ in 0..3 { + input.push(PAD_BYTE); + assert_eq!( + Err(DecodeError::InvalidByte(len, PAD_BYTE)), + engine.decode(&input) + ); + } + } + } +} + +#[apply(all_engines)] +fn decode_detect_1_invalid_byte_in_last_quad_invalid_byte(engine_wrapper: E) { + for prefix_len in (0_usize..256).map(|len| len * 4) { + for mode in all_pad_modes() { + let mut input = vec![b'A'; prefix_len]; + input.push(b'*'); + + let engine = E::standard_with_pad_mode(true, mode); + + assert_eq!( + Err(DecodeError::InvalidByte(prefix_len, b'*')), + engine.decode(&input) + ); + // adding padding doesn't matter + for _ in 0..3 { + input.push(PAD_BYTE); + assert_eq!( + Err(DecodeError::InvalidByte(prefix_len, b'*')), + engine.decode(&input) + ); + } + } } } @@ -471,8 +494,10 @@ fn decode_detect_invalid_last_symbol_every_possible_three_symbols(engine_wrapper: E) { /// Any amount of padding anywhere before the final non padding character = invalid byte at first /// pad byte. -/// From this, we know padding must extend to the end of the input. -// DecoderReader pseudo-engine detects InvalidLastSymbol instead of InvalidLength because it -// can end a decode on the quad that happens to contain the start of the padding -#[apply(all_engines_except_decoder_reader)] -fn decode_padding_before_final_non_padding_char_error_invalid_byte( +/// From this and [decode_padding_before_final_non_padding_char_error_invalid_byte_at_first_pad_non_canonical_padding_suffix_all_modes], +/// we know padding must extend contiguously to the end of the input. +#[apply(all_engines)] +fn decode_padding_before_final_non_padding_char_error_invalid_byte_at_first_pad_all_modes< + E: EngineWrapper, +>( engine_wrapper: E, ) { - let mut rng = seeded_rng(); + // Different amounts of padding, w/ offset from end for the last non-padding char. + // Only canonical padding, so Canonical mode will work. + let suffixes = &[("AA==", 2), ("AAA=", 1), ("AAAA", 0)]; - // the different amounts of proper padding, w/ offset from end for the last non-padding char - let suffixes = [("/w==", 2), ("iYu=", 1), ("zzzz", 0)]; + for mode in pad_modes_allowing_padding() { + // We don't encode, so we don't care about encode padding. + let engine = E::standard_with_pad_mode(true, mode); - let prefix_quads_range = distributions::Uniform::from(0..=256); + decode_padding_before_final_non_padding_char_error_invalid_byte_at_first_pad( + engine, + suffixes.as_slice(), + ); + } +} - for mode in all_pad_modes() { - // we don't encode so we don't care about encode padding - let engine = E::standard_with_pad_mode(true, mode); +/// See [decode_padding_before_final_non_padding_char_error_invalid_byte_at_first_pad_all_modes] +#[apply(all_engines)] +fn decode_padding_before_final_non_padding_char_error_invalid_byte_at_first_pad_non_canonical_padding_suffix< + E: EngineWrapper, +>( + engine_wrapper: E, +) { + // Different amounts of padding, w/ offset from end for the last non-padding char, and + // non-canonical padding. + let suffixes = [ + ("AA==", 2), + ("AA=", 1), + ("AA", 0), + ("AAA=", 1), + ("AAA", 0), + ("AAAA", 0), + ]; - for _ in 0..100_000 { - for (suffix, offset) in suffixes.iter() { - let mut s = "ABCD".repeat(prefix_quads_range.sample(&mut rng)); - s.push_str(suffix); - let mut encoded = s.into_bytes(); + // We don't encode, so we don't care about encode padding. + // Decoding is indifferent so that we don't get caught by missing padding on the last quad + let engine = E::standard_with_pad_mode(true, DecodePaddingMode::Indifferent); + + decode_padding_before_final_non_padding_char_error_invalid_byte_at_first_pad( + engine, + suffixes.as_slice(), + ) +} + +fn decode_padding_before_final_non_padding_char_error_invalid_byte_at_first_pad( + engine: impl Engine, + suffixes: &[(&str, usize)], +) { + let mut rng = seeded_rng(); - // calculate a range to write padding into that leaves at least one non padding char - let last_non_padding_offset = encoded.len() - 1 - offset; + let prefix_quads_range = distributions::Uniform::from(0..=256); - // don't include last non padding char as it must stay not padding - let padding_end = rng.gen_range(0..last_non_padding_offset); + for _ in 0..100_000 { + for (suffix, suffix_offset) in suffixes.iter() { + let mut s = "AAAA".repeat(prefix_quads_range.sample(&mut rng)); + s.push_str(suffix); + let mut encoded = s.into_bytes(); - // don't use more than 100 bytes of padding, but also use shorter lengths when - // padding_end is near the start of the encoded data to avoid biasing to padding - // the entire prefix on short lengths - let padding_len = rng.gen_range(1..=usize::min(100, padding_end + 1)); - let padding_start = padding_end.saturating_sub(padding_len); + // calculate a range to write padding into that leaves at least one non padding char + let last_non_padding_offset = encoded.len() - 1 - suffix_offset; - encoded[padding_start..=padding_end].fill(PAD_BYTE); + // don't include last non padding char as it must stay not padding + let padding_end = rng.gen_range(0..last_non_padding_offset); - assert_eq!( - Err(DecodeError::InvalidByte(padding_start, PAD_BYTE)), - engine.decode(&encoded), - ); - } + // don't use more than 100 bytes of padding, but also use shorter lengths when + // padding_end is near the start of the encoded data to avoid biasing to padding + // the entire prefix on short lengths + let padding_len = rng.gen_range(1..=usize::min(100, padding_end + 1)); + let padding_start = padding_end.saturating_sub(padding_len); + + encoded[padding_start..=padding_end].fill(PAD_BYTE); + + // should still have non-padding before any final padding + assert_ne!(PAD_BYTE, encoded[last_non_padding_offset]); + assert_eq!( + Err(DecodeError::InvalidByte(padding_start, PAD_BYTE)), + engine.decode(&encoded), + "len: {}, input: {}", + encoded.len(), + String::from_utf8(encoded).unwrap() + ); } } } -/// Any amount of padding before final chunk that crosses over into final chunk with 2-4 bytes = +/// Any amount of padding before final chunk that crosses over into final chunk with 1-4 bytes = /// invalid byte at first pad byte. -/// From this and [decode_padding_starts_before_final_chunk_error_invalid_length] we know the -/// padding must start in the final chunk. -// DecoderReader pseudo-engine detects InvalidLastSymbol instead of InvalidLength because it -// can end a decode on the quad that happens to contain the start of the padding -#[apply(all_engines_except_decoder_reader)] -fn decode_padding_starts_before_final_chunk_error_invalid_byte( +/// From this we know the padding must start in the final chunk. +#[apply(all_engines)] +fn decode_padding_starts_before_final_chunk_error_invalid_byte_at_first_pad( engine_wrapper: E, ) { let mut rng = seeded_rng(); // must have at least one prefix quad let prefix_quads_range = distributions::Uniform::from(1..256); - // excluding 1 since we don't care about invalid length in this test - let suffix_pad_len_range = distributions::Uniform::from(2..=4); - for mode in all_pad_modes() { + let suffix_pad_len_range = distributions::Uniform::from(1..=4); + // don't use no-padding mode, as the reader decode might decode a block that ends with + // valid padding, which should then be referenced when encountering the later invalid byte + for mode in pad_modes_allowing_padding() { // we don't encode so we don't care about encode padding let engine = E::standard_with_pad_mode(true, mode); for _ in 0..100_000 { let suffix_len = suffix_pad_len_range.sample(&mut rng); - let mut encoded = "ABCD" + // all 0 bits so we don't hit InvalidLastSymbol with the reader decoder + let mut encoded = "AAAA" .repeat(prefix_quads_range.sample(&mut rng)) .into_bytes(); encoded.resize(encoded.len() + suffix_len, PAD_BYTE); @@ -705,40 +774,6 @@ fn decode_padding_starts_before_final_chunk_error_invalid_byte } } -/// Any amount of padding before final chunk that crosses over into final chunk with 1 byte = -/// invalid length. -/// From this we know the padding must start in the final chunk. -// DecoderReader pseudo-engine detects InvalidByte instead of InvalidLength because it starts by -// decoding only the available complete quads -#[apply(all_engines_except_decoder_reader)] -fn decode_padding_starts_before_final_chunk_error_invalid_length( - engine_wrapper: E, -) { - let mut rng = seeded_rng(); - - // must have at least one prefix quad - let prefix_quads_range = distributions::Uniform::from(1..256); - for mode in all_pad_modes() { - // we don't encode so we don't care about encode padding - let engine = E::standard_with_pad_mode(true, mode); - for _ in 0..100_000 { - let mut encoded = "ABCD" - .repeat(prefix_quads_range.sample(&mut rng)) - .into_bytes(); - encoded.resize(encoded.len() + 1, PAD_BYTE); - - // amount of padding must be long enough to extend back from suffix into previous - // quads - let padding_len = rng.gen_range(1 + 1..encoded.len()); - // no non-padding after padding in this test, so padding goes to the end - let padding_start = encoded.len() - padding_len; - encoded[padding_start..].fill(PAD_BYTE); - - assert_eq!(Err(DecodeError::InvalidLength), engine.decode(&encoded),); - } - } -} - /// 0-1 bytes of data before any amount of padding in final chunk = invalid byte, since padding /// is not valid data (consistent with error for pad bytes in earlier chunks). /// From this we know there must be 2-3 bytes of data before padding @@ -756,29 +791,22 @@ fn decode_too_little_data_before_padding_error_invalid_byte(en let suffix_data_len = suffix_data_len_range.sample(&mut rng); let prefix_quad_len = prefix_quads_range.sample(&mut rng); - // ensure there is a suffix quad - let min_padding = usize::from(suffix_data_len == 0); - // for all possible padding lengths - for padding_len in min_padding..=(4 - suffix_data_len) { + for padding_len in 1..=(4 - suffix_data_len) { let mut encoded = "ABCD".repeat(prefix_quad_len).into_bytes(); encoded.resize(encoded.len() + suffix_data_len, b'A'); encoded.resize(encoded.len() + padding_len, PAD_BYTE); - if suffix_data_len + padding_len == 1 { - assert_eq!(Err(DecodeError::InvalidLength), engine.decode(&encoded),); - } else { - assert_eq!( - Err(DecodeError::InvalidByte( - prefix_quad_len * 4 + suffix_data_len, - PAD_BYTE, - )), - engine.decode(&encoded), - "suffix data len {} pad len {}", - suffix_data_len, - padding_len - ); - } + assert_eq!( + Err(DecodeError::InvalidByte( + prefix_quad_len * 4 + suffix_data_len, + PAD_BYTE, + )), + engine.decode(&encoded), + "suffix data len {} pad len {}", + suffix_data_len, + padding_len + ); } } } @@ -918,258 +946,64 @@ fn decode_pad_mode_indifferent_padding_accepts_anything(engine ); } -//this is a MAY in the rfc: https://tools.ietf.org/html/rfc4648#section-3.3 -// DecoderReader pseudo-engine finds the first padding, but doesn't report it as an error, -// because in the next decode it finds more padding, which is reported as InvalidByte, just -// with an offset at its position in the second decode, rather than being linked to the start -// of the padding that was first seen in the previous decode. -#[apply(all_engines_except_decoder_reader)] -fn decode_pad_byte_in_penultimate_quad_error(engine_wrapper: E) { - for mode in all_pad_modes() { - // we don't encode so we don't care about encode padding - let engine = E::standard_with_pad_mode(true, mode); - - for num_prefix_quads in 0..256 { - // leave room for at least one pad byte in penultimate quad - for num_valid_bytes_penultimate_quad in 0..4 { - // can't have 1 or it would be invalid length - for num_pad_bytes_in_final_quad in 2..=4 { - let mut s: String = "ABCD".repeat(num_prefix_quads); - - // varying amounts of padding in the penultimate quad - for _ in 0..num_valid_bytes_penultimate_quad { - s.push('A'); - } - // finish penultimate quad with padding - for _ in num_valid_bytes_penultimate_quad..4 { - s.push('='); - } - // and more padding in the final quad - for _ in 0..num_pad_bytes_in_final_quad { - s.push('='); - } - - // padding should be an invalid byte before the final quad. - // Could argue that the *next* padding byte (in the next quad) is technically the first - // erroneous one, but reporting that accurately is more complex and probably nobody cares - assert_eq!( - DecodeError::InvalidByte( - num_prefix_quads * 4 + num_valid_bytes_penultimate_quad, - b'=', - ), - engine.decode(&s).unwrap_err(), - ); - } - } - } - } -} - -#[apply(all_engines)] -fn decode_bytes_after_padding_in_final_quad_error(engine_wrapper: E) { - for mode in all_pad_modes() { - // we don't encode so we don't care about encode padding - let engine = E::standard_with_pad_mode(true, mode); - - for num_prefix_quads in 0..256 { - // leave at least one byte in the quad for padding - for bytes_after_padding in 1..4 { - let mut s: String = "ABCD".repeat(num_prefix_quads); - - // every invalid padding position with a 3-byte final quad: 1 to 3 bytes after padding - for _ in 0..(3 - bytes_after_padding) { - s.push('A'); - } - s.push('='); - for _ in 0..bytes_after_padding { - s.push('A'); - } - - // First (and only) padding byte is invalid. - assert_eq!( - DecodeError::InvalidByte( - num_prefix_quads * 4 + (3 - bytes_after_padding), - b'=' - ), - engine.decode(&s).unwrap_err() - ); - } - } - } -} - -#[apply(all_engines)] -fn decode_absurd_pad_error(engine_wrapper: E) { - for mode in all_pad_modes() { - // we don't encode so we don't care about encode padding - let engine = E::standard_with_pad_mode(true, mode); - - for num_prefix_quads in 0..256 { - let mut s: String = "ABCD".repeat(num_prefix_quads); - s.push_str("==Y=Wx===pY=2U====="); - - // first padding byte - assert_eq!( - DecodeError::InvalidByte(num_prefix_quads * 4, b'='), - engine.decode(&s).unwrap_err() - ); - } - } -} - -// DecoderReader pseudo-engine detects InvalidByte instead of InvalidLength because it starts by -// decoding only the available complete quads -#[apply(all_engines_except_decoder_reader)] -fn decode_too_much_padding_returns_error(engine_wrapper: E) { - for mode in all_pad_modes() { - // we don't encode so we don't care about encode padding - let engine = E::standard_with_pad_mode(true, mode); - - for num_prefix_quads in 0..256 { - // add enough padding to ensure that we'll hit all decode stages at the different lengths - for pad_bytes in 1..=64 { - let mut s: String = "ABCD".repeat(num_prefix_quads); - let padding: String = "=".repeat(pad_bytes); - s.push_str(&padding); - - if pad_bytes % 4 == 1 { - assert_eq!(DecodeError::InvalidLength, engine.decode(&s).unwrap_err()); - } else { - assert_eq!( - DecodeError::InvalidByte(num_prefix_quads * 4, b'='), - engine.decode(&s).unwrap_err() - ); - } - } - } - } -} - -// DecoderReader pseudo-engine detects InvalidByte instead of InvalidLength because it starts by -// decoding only the available complete quads -#[apply(all_engines_except_decoder_reader)] -fn decode_padding_followed_by_non_padding_returns_error(engine_wrapper: E) { - for mode in all_pad_modes() { - // we don't encode so we don't care about encode padding - let engine = E::standard_with_pad_mode(true, mode); - - for num_prefix_quads in 0..256 { - for pad_bytes in 0..=32 { - let mut s: String = "ABCD".repeat(num_prefix_quads); - let padding: String = "=".repeat(pad_bytes); - s.push_str(&padding); - s.push('E'); - - if pad_bytes % 4 == 0 { - assert_eq!(DecodeError::InvalidLength, engine.decode(&s).unwrap_err()); - } else { - assert_eq!( - DecodeError::InvalidByte(num_prefix_quads * 4, b'='), - engine.decode(&s).unwrap_err() - ); - } - } - } - } -} - -#[apply(all_engines)] -fn decode_one_char_in_final_quad_with_padding_error(engine_wrapper: E) { - for mode in all_pad_modes() { - // we don't encode so we don't care about encode padding - let engine = E::standard_with_pad_mode(true, mode); - - for num_prefix_quads in 0..256 { - let mut s: String = "ABCD".repeat(num_prefix_quads); - s.push_str("E="); - - assert_eq!( - DecodeError::InvalidByte(num_prefix_quads * 4 + 1, b'='), - engine.decode(&s).unwrap_err() - ); - - // more padding doesn't change the error - s.push('='); - assert_eq!( - DecodeError::InvalidByte(num_prefix_quads * 4 + 1, b'='), - engine.decode(&s).unwrap_err() - ); - - s.push('='); - assert_eq!( - DecodeError::InvalidByte(num_prefix_quads * 4 + 1, b'='), - engine.decode(&s).unwrap_err() - ); - } - } -} - -#[apply(all_engines)] -fn decode_too_few_symbols_in_final_quad_error(engine_wrapper: E) { - for mode in all_pad_modes() { - // we don't encode so we don't care about encode padding - let engine = E::standard_with_pad_mode(true, mode); - - for num_prefix_quads in 0..256 { - // <2 is invalid - for final_quad_symbols in 0..2 { - for padding_symbols in 0..=(4 - final_quad_symbols) { - let mut s: String = "ABCD".repeat(num_prefix_quads); - - for _ in 0..final_quad_symbols { - s.push('A'); - } - for _ in 0..padding_symbols { - s.push('='); - } - - match final_quad_symbols + padding_symbols { - 0 => continue, - 1 => { - assert_eq!(DecodeError::InvalidLength, engine.decode(&s).unwrap_err()); - } - _ => { - // error reported at first padding byte - assert_eq!( - DecodeError::InvalidByte( - num_prefix_quads * 4 + final_quad_symbols, - b'=', - ), - engine.decode(&s).unwrap_err() - ); - } - } - } - } - } - } -} - +/// 1 trailing byte that's not padding is detected as invalid byte even though there's padding +/// in the middle of the input. This is essentially mandating the eager check for 1 trailing byte +/// to catch the \n suffix case. // DecoderReader pseudo-engine can't handle DecodePaddingMode::RequireNone since it will decode // a complete quad with padding in it before encountering the stray byte that makes it an invalid // length #[apply(all_engines_except_decoder_reader)] -fn decode_invalid_trailing_bytes(engine_wrapper: E) { +fn decode_invalid_trailing_bytes_all_pad_modes_invalid_byte(engine_wrapper: E) { for mode in all_pad_modes() { do_invalid_trailing_byte(E::standard_with_pad_mode(true, mode), mode); } } #[apply(all_engines)] -fn decode_invalid_trailing_bytes_all_modes(engine_wrapper: E) { +fn decode_invalid_trailing_bytes_invalid_byte(engine_wrapper: E) { // excluding no padding mode because the DecoderWrapper pseudo-engine will fail with // InvalidPadding because it will decode the last complete quad with padding first for mode in pad_modes_allowing_padding() { do_invalid_trailing_byte(E::standard_with_pad_mode(true, mode), mode); } } +fn do_invalid_trailing_byte(engine: impl Engine, mode: DecodePaddingMode) { + for last_byte in [b'*', b'\n'] { + for num_prefix_quads in 0..256 { + let mut s: String = "ABCD".repeat(num_prefix_quads); + s.push_str("Cg=="); + let mut input = s.into_bytes(); + input.push(last_byte); + + // The case of trailing newlines is common enough to warrant a test for a good error + // message. + assert_eq!( + Err(DecodeError::InvalidByte( + num_prefix_quads * 4 + 4, + last_byte + )), + engine.decode(&input), + "mode: {:?}, input: {}", + mode, + String::from_utf8(input).unwrap() + ); + } + } +} +/// When there's 1 trailing byte, but it's padding, it's only InvalidByte if there isn't padding +/// earlier. #[apply(all_engines)] -fn decode_invalid_trailing_padding_as_invalid_length(engine_wrapper: E) { +fn decode_invalid_trailing_padding_as_invalid_byte_at_first_pad_byte( + engine_wrapper: E, +) { // excluding no padding mode because the DecoderWrapper pseudo-engine will fail with // InvalidPadding because it will decode the last complete quad with padding first for mode in pad_modes_allowing_padding() { - do_invalid_trailing_padding_as_invalid_length(E::standard_with_pad_mode(true, mode), mode); + do_invalid_trailing_padding_as_invalid_byte_at_first_padding( + E::standard_with_pad_mode(true, mode), + mode, + ); } } @@ -1177,48 +1011,36 @@ fn decode_invalid_trailing_padding_as_invalid_length(engine_wr // a complete quad with padding in it before encountering the stray byte that makes it an invalid // length #[apply(all_engines_except_decoder_reader)] -fn decode_invalid_trailing_padding_as_invalid_length_all_modes( +fn decode_invalid_trailing_padding_as_invalid_byte_at_first_byte_all_modes( engine_wrapper: E, ) { for mode in all_pad_modes() { - do_invalid_trailing_padding_as_invalid_length(E::standard_with_pad_mode(true, mode), mode); + do_invalid_trailing_padding_as_invalid_byte_at_first_padding( + E::standard_with_pad_mode(true, mode), + mode, + ); } } - -#[apply(all_engines)] -fn decode_wrong_length_error(engine_wrapper: E) { - let engine = E::standard_with_pad_mode(true, DecodePaddingMode::Indifferent); - +fn do_invalid_trailing_padding_as_invalid_byte_at_first_padding( + engine: impl Engine, + mode: DecodePaddingMode, +) { for num_prefix_quads in 0..256 { - // at least one token, otherwise it wouldn't be a final quad - for num_tokens_final_quad in 1..=4 { - for num_padding in 0..=(4 - num_tokens_final_quad) { - let mut s: String = "IIII".repeat(num_prefix_quads); - for _ in 0..num_tokens_final_quad { - s.push('g'); - } - for _ in 0..num_padding { - s.push('='); - } + for (suffix, pad_offset) in [("AA===", 2), ("AAA==", 3), ("AAAA=", 4)] { + let mut s: String = "ABCD".repeat(num_prefix_quads); + s.push_str(suffix); - let res = engine.decode(&s); - if num_tokens_final_quad >= 2 { - assert!(res.is_ok()); - } else if num_tokens_final_quad == 1 && num_padding > 0 { - // = is invalid if it's too early - assert_eq!( - Err(DecodeError::InvalidByte( - num_prefix_quads * 4 + num_tokens_final_quad, - 61 - )), - res - ); - } else if num_padding > 2 { - assert_eq!(Err(DecodeError::InvalidPadding), res); - } else { - assert_eq!(Err(DecodeError::InvalidLength), res); - } - } + assert_eq!( + // pad after `g`, not the last one + Err(DecodeError::InvalidByte( + num_prefix_quads * 4 + pad_offset, + PAD_BYTE + )), + engine.decode(&s), + "mode: {:?}, input: {}", + mode, + s + ); } } } @@ -1248,14 +1070,23 @@ fn decode_into_slice_fits_in_precisely_sized_slice(engine_wrap assert_encode_sanity(&encoded_data, engine.config().encode_padding(), input_len); decode_buf.resize(input_len, 0); - // decode into the non-empty buf let decode_bytes_written = engine .decode_slice_unchecked(encoded_data.as_bytes(), &mut decode_buf[..]) .unwrap(); - assert_eq!(orig_data.len(), decode_bytes_written); assert_eq!(orig_data, decode_buf); + + // TODO + // same for checked variant + // decode_buf.clear(); + // decode_buf.resize(input_len, 0); + // // decode into the non-empty buf + // let decode_bytes_written = engine + // .decode_slice(encoded_data.as_bytes(), &mut decode_buf[..]) + // .unwrap(); + // assert_eq!(orig_data.len(), decode_bytes_written); + // assert_eq!(orig_data, decode_buf); } } @@ -1355,38 +1186,6 @@ fn estimate_via_u128_inflation(engine_wrapper: E) { }) } -fn do_invalid_trailing_byte(engine: impl Engine, mode: DecodePaddingMode) { - for num_prefix_quads in 0..256 { - let mut s: String = "ABCD".repeat(num_prefix_quads); - s.push_str("Cg==\n"); - - // The case of trailing newlines is common enough to warrant a test for a good error - // message. - assert_eq!( - Err(DecodeError::InvalidByte(num_prefix_quads * 4 + 4, b'\n')), - engine.decode(&s), - "mode: {:?}, input: {}", - mode, - s - ); - } -} - -fn do_invalid_trailing_padding_as_invalid_length(engine: impl Engine, mode: DecodePaddingMode) { - for num_prefix_quads in 0..256 { - let mut s: String = "ABCD".repeat(num_prefix_quads); - s.push_str("Cg==="); - - assert_eq!( - Err(DecodeError::InvalidLength), - engine.decode(&s), - "mode: {:?}, input: {}", - mode, - s - ); - } -} - /// Returns a tuple of the original data length, the encoded data length (just data), and the length including padding. /// /// Vecs provided should be empty. diff --git a/src/read/decoder.rs b/src/read/decoder.rs index b656ae3..125eeab 100644 --- a/src/read/decoder.rs +++ b/src/read/decoder.rs @@ -35,37 +35,39 @@ pub struct DecoderReader<'e, E: Engine, R: io::Read> { /// Where b64 data is read from inner: R, - // Holds b64 data read from the delegate reader. + /// Holds b64 data read from the delegate reader. b64_buffer: [u8; BUF_SIZE], - // The start of the pending buffered data in b64_buffer. + /// The start of the pending buffered data in `b64_buffer`. b64_offset: usize, - // The amount of buffered b64 data. + /// The amount of buffered b64 data after `b64_offset` in `b64_len`. b64_len: usize, - // Since the caller may provide us with a buffer of size 1 or 2 that's too small to copy a - // decoded chunk in to, we have to be able to hang on to a few decoded bytes. - // Technically we only need to hold 2 bytes but then we'd need a separate temporary buffer to - // decode 3 bytes into and then juggle copying one byte into the provided read buf and the rest - // into here, which seems like a lot of complexity for 1 extra byte of storage. - decoded_buffer: [u8; DECODED_CHUNK_SIZE], - // index of start of decoded data + /// Since the caller may provide us with a buffer of size 1 or 2 that's too small to copy a + /// decoded chunk in to, we have to be able to hang on to a few decoded bytes. + /// Technically we only need to hold 2 bytes, but then we'd need a separate temporary buffer to + /// decode 3 bytes into and then juggle copying one byte into the provided read buf and the rest + /// into here, which seems like a lot of complexity for 1 extra byte of storage. + decoded_chunk_buffer: [u8; DECODED_CHUNK_SIZE], + /// Index of start of decoded data in `decoded_chunk_buffer` decoded_offset: usize, - // length of decoded data + /// Length of decoded data after `decoded_offset` in `decoded_chunk_buffer` decoded_len: usize, - // used to provide accurate offsets in errors - total_b64_decoded: usize, - // offset of previously seen padding, if any + /// Input length consumed so far. + /// Used to provide accurate offsets in errors + input_consumed_len: usize, + /// offset of previously seen padding, if any padding_offset: Option, } +// exclude b64_buffer as it's uselessly large impl<'e, E: Engine, R: io::Read> fmt::Debug for DecoderReader<'e, E, R> { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { f.debug_struct("DecoderReader") .field("b64_offset", &self.b64_offset) .field("b64_len", &self.b64_len) - .field("decoded_buffer", &self.decoded_buffer) + .field("decoded_chunk_buffer", &self.decoded_chunk_buffer) .field("decoded_offset", &self.decoded_offset) .field("decoded_len", &self.decoded_len) - .field("total_b64_decoded", &self.total_b64_decoded) + .field("input_consumed_len", &self.input_consumed_len) .field("padding_offset", &self.padding_offset) .finish() } @@ -80,10 +82,10 @@ impl<'e, E: Engine, R: io::Read> DecoderReader<'e, E, R> { b64_buffer: [0; BUF_SIZE], b64_offset: 0, b64_len: 0, - decoded_buffer: [0; DECODED_CHUNK_SIZE], + decoded_chunk_buffer: [0; DECODED_CHUNK_SIZE], decoded_offset: 0, decoded_len: 0, - total_b64_decoded: 0, + input_consumed_len: 0, padding_offset: None, } } @@ -100,7 +102,7 @@ impl<'e, E: Engine, R: io::Read> DecoderReader<'e, E, R> { debug_assert!(copy_len <= self.decoded_len); buf[..copy_len].copy_from_slice( - &self.decoded_buffer[self.decoded_offset..self.decoded_offset + copy_len], + &self.decoded_chunk_buffer[self.decoded_offset..self.decoded_offset + copy_len], ); self.decoded_offset += copy_len; @@ -146,18 +148,22 @@ impl<'e, E: Engine, R: io::Read> DecoderReader<'e, E, R> { ) .map_err(|e| match e { DecodeError::InvalidByte(offset, byte) => { - // This can be incorrect, but not in a way that probably matters to anyone: - // if there was padding handled in a previous decode, and we are now getting - // InvalidByte due to more padding, we should arguably report InvalidByte with - // PAD_BYTE at the original padding position (`self.padding_offset`), but we - // don't have a good way to tie those two cases together, so instead we - // just report the invalid byte as if the previous padding, and its possibly - // related downgrade to a now invalid byte, didn't happen. - DecodeError::InvalidByte(self.total_b64_decoded + offset, byte) + match (byte, self.padding_offset) { + // if there was padding in a previous block of decoding that happened to + // be correct, and we now find more padding that happens to be incorrect, + // to be consistent with non-reader decodes, record the error at the first + // padding + (PAD_BYTE, Some(first_pad_offset)) => { + DecodeError::InvalidByte(first_pad_offset, PAD_BYTE) + } + _ => DecodeError::InvalidByte(self.input_consumed_len + offset, byte), + } + } + DecodeError::InvalidLength(len) => { + DecodeError::InvalidLength(self.input_consumed_len + len) } - DecodeError::InvalidLength => DecodeError::InvalidLength, DecodeError::InvalidLastSymbol(offset, byte) => { - DecodeError::InvalidLastSymbol(self.total_b64_decoded + offset, byte) + DecodeError::InvalidLastSymbol(self.input_consumed_len + offset, byte) } DecodeError::InvalidPadding => DecodeError::InvalidPadding, }) @@ -176,8 +182,8 @@ impl<'e, E: Engine, R: io::Read> DecoderReader<'e, E, R> { self.padding_offset = self.padding_offset.or(decode_metadata .padding_offset - .map(|offset| self.total_b64_decoded + offset)); - self.total_b64_decoded += b64_len_to_decode; + .map(|offset| self.input_consumed_len + offset)); + self.input_consumed_len += b64_len_to_decode; self.b64_offset += b64_len_to_decode; self.b64_len -= b64_len_to_decode; @@ -283,7 +289,7 @@ impl<'e, E: Engine, R: io::Read> io::Read for DecoderReader<'e, E, R> { let to_decode = cmp::min(self.b64_len, BASE64_CHUNK_SIZE); let decoded = self.decode_to_buf(to_decode, &mut decoded_chunk[..])?; - self.decoded_buffer[..decoded].copy_from_slice(&decoded_chunk[..decoded]); + self.decoded_chunk_buffer[..decoded].copy_from_slice(&decoded_chunk[..decoded]); self.decoded_offset = 0; self.decoded_len = decoded;